{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Converting Variational Models to TorchScript\n",
    "\n",
    "This notebook demonstrates how to convert a variational GPyTorch model to a TorchScript `ScriptModule` that can, for example, be exported to LibTorch.\n",
    "\n",
    "In general, the process is similar to tracing a standard PyTorch model with `torch.jit.trace`. However, there are two key differences:\n",
    "\n",
    "1. The first time you make predictions with a GPyTorch model (exact or approximate), we cache certain computations. These computations can't be traced, but their results can be. Therefore, we'll need to pass data through the untraced model once, and then trace the model.\n",
    "2. You can't trace models that return Distribution objects. Therefore, we'll write a simple wrapper that unpacks the `MultivariateNormal` returned by our GPs into just a mean tensor and a variance tensor.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download Data and Define Model\n",
    "\n",
    "In this tutorial, we'll be tracing an SVGP model trained for just 10 epochs on the `elevators` UCI dataset. The next two cells are copied directly from our variational tutorial: the first downloads the data, and the second defines the variational GP model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import urllib.request\n",
    "import os\n",
    "from scipy.io import loadmat\n",
    "from math import floor\n",
    "\n",
    "\n",
    "# this is for running the notebook in our testing framework\n",
    "smoke_test = ('CI' in os.environ)\n",
    "\n",
    "if not smoke_test and not os.path.isfile('../elevators.mat'):\n",
    "    print('Downloading \\'elevators\\' UCI dataset...')\n",
    "    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', '../elevators.mat')\n",
    "\n",
    "\n",
    "if smoke_test:  # this is for running the notebook in our testing framework\n",
    "    X, y = torch.randn(1000, 18), torch.randn(1000)\n",
    "else:\n",
    "    data = torch.Tensor(loadmat('../elevators.mat')['data'])\n",
    "    X = data[:, :-1]\n",
    "    X = X - X.min(0)[0]\n",
    "    X = 2 * (X / X.max(0)[0]) - 1\n",
    "    y = data[:, -1]\n",
    "\n",
    "\n",
    "train_n = int(floor(0.8 * len(X)))\n",
    "train_x = X[:train_n, :].contiguous()\n",
    "train_y = y[:train_n].contiguous()\n",
    "\n",
    "test_x = X[train_n:, :].contiguous()\n",
    "test_y = y[train_n:].contiguous()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gpytorch\n",
    "\n",
    "from gpytorch.models import ApproximateGP\n",
    "from gpytorch.variational import CholeskyVariationalDistribution\n",
    "from gpytorch.variational import VariationalStrategy\n",
    "\n",
    "class GPModel(ApproximateGP):\n",
    "    def __init__(self, inducing_points):\n",
    "        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))\n",
    "        variational_strategy = VariationalStrategy(self, inducing_points, variational_distribution, learn_inducing_locations=True)\n",
    "        super(GPModel, self).__init__(variational_strategy)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=18))\n",
    "        \n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "inducing_points = torch.randn(500, 18)\n",
    "model = GPModel(inducing_points=inducing_points)\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    model = model.cuda()\n",
    "    likelihood = likelihood.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load a Trained Model\n",
    "\n",
    "To keep things simple, we won't train in this notebook. Instead, we'll load the parameters of a model that was pre-trained on `elevators` in the SVGP example notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The checkpoint stores a (model state dict, likelihood state dict) pair.\n",
    "# Without a GPU, remap any CUDA tensors in the checkpoint to the CPU.\n",
    "map_location = None if torch.cuda.is_available() else 'cpu'\n",
    "model_state_dict, likelihood_state_dict = torch.load('svgp_elevators.pt', map_location=map_location)\n",
    "model.load_state_dict(model_state_dict)\n",
    "likelihood.load_state_dict(likelihood_state_dict)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a Wrapper\n",
    "\n",
    "Instead of directly tracing the GP, we'll need to trace a PyTorch Module that returns tensors. In the next cell, we define a wrapper that calls a GP and then unpacks the resulting Distribution into a mean and a variance.\n",
    "\n",
    "If you want the full covariance rather than just the variance, you could return `covariance_matrix` instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.jit.trace needs tensor outputs, so this wrapper unpacks the GP's\n",
    "# MultivariateNormal into a (mean, variance) tuple.\n",
    "class MeanVarModelWrapper(torch.nn.Module):\n",
    "    def __init__(self, gp):\n",
    "        super().__init__()\n",
    "        self.gp = gp\n",
    "    \n",
    "    def forward(self, x):\n",
    "        output_dist = self.gp(x)\n",
    "        return output_dist.mean, output_dist.variance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Trace the Model\n",
    "\n",
    "In the next cell, we trace the model as normal, except that we first pass data through the wrapped model so that GPyTorch can compute and cache everything that can't be traced. For variational GPs, this mostly amounts to a few complicated linear algebra operations.\n",
    "\n",
    "Additionally, we'll need to run with the `gpytorch.settings.trace_mode` setting enabled, because PyTorch can't trace custom autograd Functions. Note that this results in some inefficiencies: for variational models, for example, the traced model always computes the full predictive posterior covariance. In practice this is manageable, because we can always process the data in minibatches.\n",
    "\n",
    "**Note:** You'll get a lot of warnings from the tracer. That's fine. GPyTorch models are pretty large graphs, and include things like `.item()` calls that you wouldn't normally encounter in a basic neural network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "wrapped_model = MeanVarModelWrapper(model)\n",
    "\n",
    "with torch.no_grad(), gpytorch.settings.trace_mode():\n",
    "    fake_input = test_x[:1024, :]\n",
    "    pred = wrapped_model(fake_input)  # Compute caches\n",
    "    traced_model = torch.jit.trace(wrapped_model, fake_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0756, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "tensor(0.0756, device='cuda:0', grad_fn=<MeanBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# Compute the mean absolute error on a minibatch; the original and traced models should agree\n",
    "\n",
    "mean1 = wrapped_model(test_x[:1024, :])[0]\n",
    "mean2 = traced_model(test_x[:1024, :])[0]\n",
    "\n",
    "print(torch.mean(torch.abs(mean1 - test_y[:1024])))\n",
    "print(torch.mean(torch.abs(mean2 - test_y[:1024])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "traced_model.save('traced_model.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
